{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
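    "# TensorFlow 2.0 Tutorial: Classifying Structured Data\n",
    "\n",
    "TensorFlow 2 tutorial column on Zhihu: https://zhuanlan.zhihu.com/c_1091021863043624960\n",
    "\n",
    "This tutorial shows how to classify structured data (for example, tabular data in a CSV file). We use Keras to define the model and convert the CSV columns into the features used as training input. The tutorial contains code to:\n",
    "\n",
    "- Load a CSV file with Pandas.\n",
    "- Build an input pipeline with tf.data to batch and shuffle the rows.\n",
    "- Map the columns of the CSV to the input features used to train the model.\n",
    "- Build, train, and evaluate a model with Keras."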
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-alpha0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import feature_column\n",
    "from tensorflow.keras import layers\n",
    "from sklearn.model_selection import train_test_split\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
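    "## 1. Dataset\n",
    "We will use a small dataset provided by the Cleveland Clinic Foundation for Heart Disease. The CSV has a few hundred rows. Each row describes a patient and each column describes an attribute. We will use this information to predict whether a patient has heart disease, which is a binary classification task on this dataset.\n",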
    "\n",
    ">Column| Description| Feature Type | Data Type\n",
    ">------------|--------------------|----------------------|-----------------\n",
    ">Age | Age in years | Numerical | integer\n",
    ">Sex | (1 = male; 0 = female) | Categorical | integer\n",
    ">CP | Chest pain type (0, 1, 2, 3, 4) | Categorical | integer\n",
    ">Trestbps | Resting blood pressure (in mm Hg on admission to the hospital) | Numerical | integer\n",
    ">Chol | Serum cholesterol in mg/dl | Numerical | integer\n",
    ">FBS | (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false) | Categorical | integer\n",
    ">RestECG | Resting electrocardiographic results (0, 1, 2) | Categorical | integer\n",
    ">Thalach | Maximum heart rate achieved | Numerical | integer\n",
    ">Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical | integer\n",
    ">Oldpeak | ST depression induced by exercise relative to rest | Numerical | float\n",
    ">Slope | The slope of the peak exercise ST segment | Numerical | integer\n",
    ">CA | Number of major vessels (0-3) colored by fluoroscopy | Numerical | integer\n",
    ">Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical | string\n",
    ">Target | Diagnosis of heart disease (1 = true; 0 = false) | Classification | integer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Prepare the data\n",
    "Load the data with pandas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>cp</th>\n",
       "      <th>trestbps</th>\n",
       "      <th>chol</th>\n",
       "      <th>fbs</th>\n",
       "      <th>restecg</th>\n",
       "      <th>thalach</th>\n",
       "      <th>exang</th>\n",
       "      <th>oldpeak</th>\n",
       "      <th>slope</th>\n",
       "      <th>ca</th>\n",
       "      <th>thal</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>63</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>145</td>\n",
       "      <td>233</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>150</td>\n",
       "      <td>0</td>\n",
       "      <td>2.3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>fixed</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>67</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>160</td>\n",
       "      <td>286</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>108</td>\n",
       "      <td>1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>normal</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>67</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>120</td>\n",
       "      <td>229</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>129</td>\n",
       "      <td>1</td>\n",
       "      <td>2.6</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>reversible</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>37</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>130</td>\n",
       "      <td>250</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>187</td>\n",
       "      <td>0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>normal</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>41</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>130</td>\n",
       "      <td>204</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>172</td>\n",
       "      <td>0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>normal</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  \\\n",
       "0   63    1   1       145   233    1        2      150      0      2.3      3   \n",
       "1   67    1   4       160   286    0        2      108      1      1.5      2   \n",
       "2   67    1   4       120   229    0        2      129      1      2.6      2   \n",
       "3   37    1   3       130   250    0        0      187      0      3.5      3   \n",
       "4   41    0   2       130   204    0        2      172      0      1.4      1   \n",
       "\n",
       "   ca        thal  target  \n",
       "0   0       fixed       0  \n",
       "1   3      normal       1  \n",
       "2   2  reversible       0  \n",
       "3   0      normal       0  \n",
       "4   0      normal       0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "URL = 'https://storage.googleapis.com/applied-dl/heart.csv'\n",
    "dataframe = pd.read_csv(URL)\n",
    "dataframe.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the data into training, validation, and test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "193 train examples\n",
      "49 validation examples\n",
      "61 test examples\n"
     ]
    }
   ],
   "source": [
    "train, test = train_test_split(dataframe, test_size=0.2)\n",
    "train, val = train_test_split(train, test_size=0.2)\n",
    "print(len(train), 'train examples')\n",
    "print(len(val), 'validation examples')\n",
    "print(len(test), 'test examples')"
   ]
  },
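  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three counts printed above can be checked by hand. The sketch below is only an illustration: it assumes that scikit-learn rounds a fractional test_size up to a whole number of rows and gives the remainder to the training side, and it takes the total of 303 rows from the sum of the printed counts. split_sizes is a hypothetical helper written for this check, not a scikit-learn function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def split_sizes(n_samples, test_size):\n",
    "    # Assumption: the held-out share is rounded up, the rest is kept.\n",
    "    n_held_out = math.ceil(test_size * n_samples)\n",
    "    return n_samples - n_held_out, n_held_out\n",
    "\n",
    "n_total = 193 + 49 + 61  # the three counts printed above\n",
    "\n",
    "# First split: hold out 20% of all rows as the test set.\n",
    "n_train_val, n_test = split_sizes(n_total, 0.2)\n",
    "# Second split: hold out 20% of the remaining rows for validation.\n",
    "n_train, n_val = split_sizes(n_train_val, 0.2)\n",
    "\n",
    "print(n_train, 'train examples')\n",
    "print(n_val, 'validation examples')\n",
    "print(n_test, 'test examples')"
   ]
  },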
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the input pipeline with tf.data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def df_to_dataset(dataframe, shuffle=True, batch_size=32):\n",
    "    # Work on a copy so the caller's DataFrame keeps its 'target' column.\n",
    "    dataframe = dataframe.copy()\n",
    "    labels = dataframe.pop('target')\n",
    "    # Each element is a (dict of column name -> value, label) pair.\n",
    "    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))\n",
    "    if shuffle:\n",
    "        # A buffer as large as the dataset gives a full shuffle.\n",
    "        ds = ds.shuffle(buffer_size=len(dataframe))\n",
    "    ds = ds.batch(batch_size)\n",
    "    return ds\n",
    "\n",
    "batch_size = 5  # a small batch keeps the demo output readable\n",
    "train_ds = df_to_dataset(train, batch_size=batch_size)\n",
    "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n",
    "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Every feature: ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal']\n",
      "A batch of ages: tf.Tensor([61 51 57 51 44], shape=(5,), dtype=int32)\n",
      "A batch of targets: tf.Tensor([0 0 0 1 0], shape=(5,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "for feature_batch, label_batch in train_ds.take(1):\n",
    "    print('Every feature:', list(feature_batch.keys()))\n",
    "    print('A batch of ages:', feature_batch['age'])\n",
    "    print('A batch of targets:', label_batch )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. TensorFlow feature columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the features (index 0) of one batch to demonstrate each column type.\n",
    "example_batch = next(iter(train_ds))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def demo(feature_column):\n",
    "    # Apply a single feature column to the example batch and print the result.\n",
    "    feature_layer = layers.DenseFeatures(feature_column)\n",
    "    print(feature_layer(example_batch).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
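    "### Numeric columns\n",
    "The output of a feature column becomes the input to the model. A numeric column is the simplest column type. It is used to represent real-valued features. With this column, the model receives the column value from the dataframe unchanged."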
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0324 13:43:10.728773 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2758: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[61.]\n",
      " [51.]\n",
      " [57.]\n",
      " [51.]\n",
      " [44.]]\n"
     ]
    }
   ],
   "source": [
    "age = feature_column.numeric_column(\"age\")\n",
    "demo(age)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
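    "### Bucketized columns\n",
    "Often you do not want to feed a number directly into the model, but instead split its value into categories based on numerical ranges. Consider raw data that represents a person's age. Instead of representing age as a numeric column, we can split it into several buckets with a bucketized column. Note that the one-hot values below show which age range each row falls into."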
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0324 13:48:31.327955 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2902: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 1. 0.]]\n"
     ]
    }
   ],
   "source": [
    "age_buckets = feature_column.bucketized_column(age, boundaries=[\n",
    "    18, 25, 30, 35, 40, 50\n",
    "])\n",
    "demo(age_buckets)"
   ]
  },
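  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bucket boundaries concrete, here is a minimal NumPy sketch of the same idea. It is not how TensorFlow implements bucketized_column; it only assumes that six boundaries produce seven half-open ranges (below 18, 18 to 25, and so on up to 50 and above) and reuses the ages from the example batch printed earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "boundaries = [18, 25, 30, 35, 40, 50]\n",
    "ages = np.array([61, 51, 57, 51, 44])  # ages from the example batch above\n",
    "\n",
    "# np.digitize returns the index of the range each age falls into:\n",
    "# 0 for age < 18, 1 for 18 <= age < 25, ..., 6 for age >= 50.\n",
    "bucket_ids = np.digitize(ages, boundaries)\n",
    "\n",
    "# Row i of the identity matrix is the one-hot vector for bucket i.\n",
    "one_hot = np.eye(len(boundaries) + 1)[bucket_ids]\n",
    "\n",
    "print(bucket_ids)\n",
    "print(one_hot)"
   ]
  },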
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
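    "### Categorical columns\n",
    "In this dataset, thal is stored as a string (for example 'fixed', 'normal', or 'reversible'). We cannot feed strings directly to a model; we must first map them to numeric values. A categorical vocabulary column represents strings as one-hot vectors (much like the age buckets above). The vocabulary can be passed as a list with categorical_column_with_vocabulary_list, or loaded from a file with categorical_column_with_vocabulary_file."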
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0324 13:55:01.628555 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4307: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "W0324 13:55:01.629235 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 1.]\n",
      " [0. 1. 0.]\n",
      " [0. 0. 1.]\n",
      " [0. 0. 1.]\n",
      " [0. 1. 0.]]\n"
     ]
    }
   ],
   "source": [
    "thal = feature_column.categorical_column_with_vocabulary_list('thal', ['fixed', 'normal', 'reversible'])\n",
    "thal_one_hot = feature_column.indicator_column(thal)\n",
    "demo(thal_one_hot)"
   ]
  },
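  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocabulary lookup can also be written out by hand. The following is a rough sketch in plain NumPy rather than the TensorFlow implementation; one_hot_encode is a hypothetical helper, and leaving out-of-vocabulary strings as all-zero rows is a simplification chosen for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "vocabulary = ['fixed', 'normal', 'reversible']\n",
    "# Map each string to its position in the vocabulary.\n",
    "index_of = {word: i for i, word in enumerate(vocabulary)}\n",
    "\n",
    "def one_hot_encode(values):\n",
    "    encoded = np.zeros((len(values), len(vocabulary)))\n",
    "    for row, value in enumerate(values):\n",
    "        # Strings outside the vocabulary stay as an all-zero row.\n",
    "        if value in index_of:\n",
    "            encoded[row, index_of[value]] = 1.0\n",
    "    return encoded\n",
    "\n",
    "print(one_hot_encode(['reversible', 'normal', 'fixed']))"
   ]
  },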
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
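    "### Embedding columns\n",
    "Suppose that instead of a few possible strings, each category has thousands of values (or more). For several reasons, training a neural network with one-hot encodings becomes infeasible as the number of categories grows. An embedding column overcomes this limitation. Instead of representing the data as a high-dimensional one-hot vector, it represents the data as a lower-dimensional dense vector in which each cell can hold any number, not just 0 or 1. The size of the embedding is a parameter that must be tuned.\n",
    "\n",
    "Note: embedding columns work best when a categorical column has many possible values."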
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.21029451  0.28502795  0.27186757 -0.13927     0.44176006  0.18506278\n",
      "  -0.14189719  0.2901029 ]\n",
      " [-0.02674027 -0.21359333 -0.26675928  0.6544374   0.12530805 -0.5243998\n",
      "  -0.23030454 -0.10796055]\n",
      " [ 0.21029451  0.28502795  0.27186757 -0.13927     0.44176006  0.18506278\n",
      "  -0.14189719  0.2901029 ]\n",
      " [ 0.21029451  0.28502795  0.27186757 -0.13927     0.44176006  0.18506278\n",
      "  -0.14189719  0.2901029 ]\n",
      " [-0.02674027 -0.21359333 -0.26675928  0.6544374   0.12530805 -0.5243998\n",
      "  -0.23030454 -0.10796055]]\n"
     ]
    }
   ],
   "source": [
    "thal_embedding = feature_column.embedding_column(thal, dimension=8)\n",
    "demo(thal_embedding)"
   ]
  },
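  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, an embedding is a table lookup: each category owns one row of a matrix, and that row is the dense vector passed to the model. The sketch below illustrates this with NumPy under stated assumptions: the table is filled with random numbers (in a real model those values are learned during training), and the batch of thal values is made up for the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "vocabulary = ['fixed', 'normal', 'reversible']\n",
    "dimension = 8\n",
    "\n",
    "# One row per category; random here, learned in a real model.\n",
    "rng = np.random.default_rng(seed=0)\n",
    "embedding_table = rng.normal(scale=0.3, size=(len(vocabulary), dimension))\n",
    "\n",
    "batch = ['reversible', 'normal', 'reversible']  # made-up example values\n",
    "ids = [vocabulary.index(value) for value in batch]\n",
    "\n",
    "# Integer indexing picks one table row per example.\n",
    "embedded = embedding_table[ids]\n",
    "print(embedded.shape)\n",
    "\n",
    "# Examples with the same category share the same vector.\n",
    "print(np.array_equal(embedded[0], embedded[2]))"
   ]
  },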
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
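    "### Hashed feature columns\n",
    "Another way to represent a categorical column with a large number of values is categorical_column_with_hash_bucket. This feature column computes a hash of the input and uses it to select one of hash_bucket_size buckets to encode the string. With this column you do not need to provide a vocabulary, and you can make the number of hash buckets much smaller than the number of actual categories to save space.\n",
    "\n",
    "Note: an important downside of this technique is that collisions can occur, where different strings are mapped to the same bucket."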
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0324 14:03:23.451644 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: HashedCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "thal_hashed = feature_column.categorical_column_with_hash_bucket('thal', hash_bucket_size=1000)\n",
    "demo(feature_column.indicator_column(thal_hashed))"
   ]
  },
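  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bucket assignment and the collision problem can be sketched with the standard library. This is an illustration only: hash_bucket is a hypothetical helper, and MD5 is an assumption standing in for whatever hash TensorFlow uses internally, so the bucket numbers printed here will not match TensorFlow's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "\n",
    "def hash_bucket(value, hash_bucket_size):\n",
    "    # Hash the string, then fold the result into the bucket range.\n",
    "    digest = hashlib.md5(value.encode('utf-8')).hexdigest()\n",
    "    return int(digest, 16) % hash_bucket_size\n",
    "\n",
    "categories = ['fixed', 'normal', 'reversible']\n",
    "for value in categories:\n",
    "    print(value, hash_bucket(value, hash_bucket_size=1000))\n",
    "\n",
    "# With only 2 buckets for 3 categories, at least two strings must collide.\n",
    "small = [hash_bucket(value, hash_bucket_size=2) for value in categories]\n",
    "print('collision:', len(set(small)) < len(categories))"
   ]
  },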
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
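    "### Crossed feature columns\n",
    "Combining features into a single feature, better known as a feature cross, lets the model learn a separate weight for each combination of features. Here we create a new feature that crosses age and thal. Note that crossed_column does not build the full table of all possible combinations (which could be very large). Instead, it is backed by a hashed_column, so you can choose the size of the table."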
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0324 14:09:05.265740 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: CrossedColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)\n",
    "demo(feature_column.indicator_column(crossed_feature))"
   ]
  },
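  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A feature cross can be pictured as hashing the combination of the two inputs. The sketch below shows that idea with NumPy and hashlib. It makes several assumptions: crossed_bucket is a hypothetical helper, the '_x_' separator and the MD5 hash are arbitrary choices for the illustration, and the resulting bucket numbers will differ from those TensorFlow computes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import numpy as np\n",
    "\n",
    "boundaries = [18, 25, 30, 35, 40, 50]\n",
    "\n",
    "def crossed_bucket(age, thal, hash_bucket_size=1000):\n",
    "    # Cross the age bucket (not the raw age) with the thal string.\n",
    "    age_bucket = int(np.digitize(age, boundaries))\n",
    "    key = f'{age_bucket}_x_{thal}'\n",
    "    digest = hashlib.md5(key.encode('utf-8')).hexdigest()\n",
    "    return int(digest, 16) % hash_bucket_size\n",
    "\n",
    "# Ages 61 and 57 share the last age bucket, so with the same thal value\n",
    "# they produce the same crossed key and therefore the same bucket.\n",
    "print(crossed_bucket(61, 'normal') == crossed_bucket(57, 'normal'))\n",
    "print(crossed_bucket(61, 'normal'), crossed_bucket(44, 'fixed'))"
   ]
  },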
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Choose which feature columns to use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_columns = []\n",
    "\n",
    "# numeric cols\n",
    "for header in ['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'slope', 'ca']:\n",
    "    feature_columns.append(feature_column.numeric_column(header))\n",
    "\n",
    "# bucketized cols\n",
    "age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n",
    "feature_columns.append(age_buckets)\n",
    "\n",
    "# indicator cols\n",
    "thal = feature_column.categorical_column_with_vocabulary_list(\n",
    "      'thal', ['fixed', 'normal', 'reversible'])\n",
    "thal_one_hot = feature_column.indicator_column(thal)\n",
    "feature_columns.append(thal_one_hot)\n",
    "\n",
    "# embedding cols\n",
    "thal_embedding = feature_column.embedding_column(thal, dimension=8)\n",
    "feature_columns.append(thal_embedding)\n",
    "\n",
    "# crossed cols\n",
    "crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)\n",
    "crossed_feature = feature_column.indicator_column(crossed_feature)\n",
    "feature_columns.append(crossed_feature)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the feature layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_layer = tf.keras.layers.DenseFeatures(feature_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "train_ds = df_to_dataset(train, batch_size=batch_size)\n",
    "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n",
    "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Build and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "7/7 [==============================] - 1s 133ms/step - loss: 1.1864 - accuracy: 0.6357 - val_loss: 0.6905 - val_accuracy: 0.5714\n",
      "Epoch 2/5\n",
      "7/7 [==============================] - 0s 24ms/step - loss: 0.9603 - accuracy: 0.6804 - val_loss: 0.4047 - val_accuracy: 0.8163\n",
      "Epoch 3/5\n",
      "7/7 [==============================] - 0s 24ms/step - loss: 0.5744 - accuracy: 0.7389 - val_loss: 0.6673 - val_accuracy: 0.7755\n",
      "Epoch 4/5\n",
      "7/7 [==============================] - 0s 24ms/step - loss: 0.4890 - accuracy: 0.8092 - val_loss: 0.6298 - val_accuracy: 0.6122\n",
      "Epoch 5/5\n",
      "7/7 [==============================] - 0s 24ms/step - loss: 0.5618 - accuracy: 0.6795 - val_loss: 0.3861 - val_accuracy: 0.8367\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fcb23e9d198>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    feature_layer,\n",
    "    layers.Dense(128, activation='relu'),\n",
    "    layers.Dense(128, activation='relu'),\n",
    "    layers.Dense(1, activation='sigmoid')\n",
    "])\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    "              loss='binary_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "model.fit(train_ds, validation_data=val_ds, epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate the model on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 [==============================] - 0s 16ms/step - loss: 0.8278 - accuracy: 0.6066\n",
      "Accuracy 0.60655737\n"
     ]
    }
   ],
   "source": [
    "loss, accuracy = model.evaluate(test_ds)\n",
    "print(\"Accuracy\", accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
